Gemma4: Full DFlash Integration (Speculative Decode + BSA Prefill + Prefix Cache) #232

Merged
davide221 merged 19 commits into Luce-Org:main from howard0su:gemma4
May 22, 2026

@howard0su
Contributor

@howard0su howard0su commented May 20, 2026

Adds complete Gemma4 support to the DFlash inference pipeline — speculative decode, BSA sparse-FA prefill, SWA ring-buffer KV cache, and prefix cache with correct feature snapshot/restore.

Key Changes

Speculative Decode (DFlash target)

  • Implement Gemma4DFlashTarget adapter (embed, verify_batch, project_hidden)
  • Wire spec-decode into Gemma4Backend with feature mirror sync
  • Replace snapshot/replay with KV truncation for ~2× fewer forward passes per spec step
  • Fix SWA causal masking and rope_theta for draft model accuracy

BSA Sparse-FA Prefill

  • Implement gemma4_prefill_bsa() — per-layer BSA for SWA layers, dense FA for full-attention layers
  • Add unified flash_prefill_forward() dispatch (BF16→F16→Q8) in flashprefill.h
  • Simplify Qwen3's prefill dispatch to use the same unified function (−68 lines)

Prefix Cache

  • Add snapshot_save calls in generate()/restore_and_generate() (was completely missing — cache hits returned 0 tokens)
  • Save/restore target_feat + last_tok in snapshots (matching Qwen35) to maintain spec-decode acceptance rate after restore
  • Full feature mirror resync after restore to ensure draft model consistency
  • Add Gemma family detection for chat-turn boundary markers

SWA Ring-Buffer & Architecture

  • Implement G5 SWA ring-buffer indexing with proper causal masks
  • Add fa_window support for full-attention layers (G6)
  • Handle variable head_dim (128 SWA / 256 full), per-layer n_head_kv, KV sharing
  • Fix attention scale, tokenizer decode, and server integration

Misc

  • Fix draft loader to read rope_theta from GGUF (remove hardcoded constant)
  • Rename draft_dflash_graph.cpp → draft_graph.cpp
  • Add graph_compute error checking in BSA prefill path

@howard0su howard0su marked this pull request as ready for review May 20, 2026 12:18
Contributor

@cubic-dev-ai cubic-dev-ai Bot left a comment


3 issues found across 25 files


Comment thread dflash/src/draft/draft_gguf_loader.cpp
Comment thread dflash/src/gemma4/gemma4_loader.cpp
Comment thread dflash/src/server/tokenizer.cpp
@dusterbloom
Contributor

Thanks for the PR, @howard0su.
I tested it locally on gemma-4-31B-it-Q4_K_M and gemma-4-26B-A4B-it-UD-Q4_K_M.

Gemma4 never actually enters the BSA sparse-FA path right now. Every layer falls through the guard and uses dense WMMA instead.

Is that expected for this PR?

In case it is, no worries. It can be added later.

One more thing to consider, only if it is in scope here:
dflash/src/gemma4/gemma4_graph.cpp:549 computes abs_k % swa_size with no guard.

Gemma4Cache::swa_size default-initialises to 0 (gemma4_internal.h:161); current call paths through create_gemma4_cache are safe (swa_size falls back to max_ctx), but any future path that constructs a Gemma4Cache outside that helper, or a GGUF where sliding_window reads as 0 while swa_layers[il] is true, would SIGFPE rather than fail cleanly.

Cheap fix:

GGML_ASSERT(swa_size > 0 && "SWA branch entered with uninitialised cache.swa_size");
const int slot = abs_k % swa_size;

Both findings reproduced against PR head de0e4c1.
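
For readers who want the guard spelled out, here is a minimal sketch of a ring-slot helper that refuses a zero ring size instead of dividing by it. The name swa_ring_slot and the std::optional return are hypothetical and are not taken from the PR; the actual fix in the branch may simply be the GGML_ASSERT shown above.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper (not from the PR): maps an absolute key position to a
// ring-buffer slot. Returns std::nullopt when the ring size is unusable, so
// the caller can report an error instead of hitting an integer
// division by zero (undefined behaviour, typically SIGFPE on x86).
std::optional<int64_t> swa_ring_slot(int64_t abs_k, int64_t swa_size) {
    if (swa_size <= 0 || abs_k < 0) {
        return std::nullopt;
    }
    return abs_k % swa_size;
}
```

A caller would check has_value() before using the slot; an assert is the cheaper choice when a zero size can only come from a programming error.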

@howard0su
Contributor Author

Good catch, I will check it. I only have a 2080 Ti, which doesn't support BSA. I will try to find an environment to debug it.

@howard0su howard0su changed the title from "Make Gemma4 full feature with dflash/pflash" to "Make Gemma4 full feature parity with dflash/pflash" May 20, 2026
Contributor

@cubic-dev-ai cubic-dev-ai Bot left a comment


1 issue found across 4 files (changes from recent commits).


Comment thread dflash/src/gemma4/gemma4_graph.cpp Outdated
@howard0su howard0su changed the title from "Make Gemma4 full feature parity with dflash/pflash" to "Gemma4: Full DFlash Integration (Speculative Decode + BSA Prefill + Prefix Cache)" May 21, 2026
howard0su and others added 17 commits May 22, 2026 07:16
Loader fixes:
- Handle array-typed metadata (head_count_kv is per-layer array)
- Fallback n_vocab from token_embd.weight tensor shape
- Default missing keys (expert_count, etc.) to 0
- Separate head_dim_full (512) and head_dim_swa (256)
- Per-layer n_head_kv_per_layer vector from GGUF array
- SWA pattern: read bool/uint8 array or infer from head_kv
- Tied embeddings: output = tok_embd when output.weight absent
- Tensor name mapping: post_attention_norm, post_ffw_norm,
  layer_output_scale
- Global rope_freqs_global tensor support

Graph fixes:
- Per-layer head_dim and n_head_kv via helper functions
- FA mask padding to 256 (FATTN_KQ_STRIDE) for CUDA compat
- Use global rope_freqs for full-attn layers

Cache:
- Per-layer KV allocation with correct dimensions

Validated: load + prefill + decode + snapshot + restore all pass
on gemma-4-31B-it-Q4_K_M.gguf (RTX 2080 Ti).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add Gemma4DFlashTarget class implementing the DFlashTarget interface:
- verify_batch: full forward with all-token argmax via gemma4_verify_batch
- snapshot_kv / restore_kv: full KV cache save/restore for rollback
- embed_tokens: CPU embedder with sqrt(n_embd) scaling
- project_hidden_to_tokens: lm_head projection via gemma4_project_hidden
- capture_layer_ids: evenly-spaced 5 layers (1, 15, 29, 43, 57)
- mask_token_id: 0 (padding token)

New graph functions:
- gemma4_verify_batch(): like gemma4_step but returns all-position argmax
- gemma4_project_hidden(): out_norm + lm_head + softcap + argmax

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Critical fixes for Gemma4 model inference:

- Fix kq_scale: Gemma4 uses self.scaling=1.0 (not 1/sqrt(head_dim))
  because Q/K already get per-head RMS norm. This was the root cause
  of garbage output (repeated token generation).

- Add SentencePiece tokenizer support: Gemma4 tokens are raw UTF-8
  with U+2581 for space, not GPT-2 byte-level encoding. Detects mode
  from tokenizer.ggml.model GGUF key. Handles encode (space->▁, UTF-8
  char splitting) and decode (▁->space) correctly.
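
The space handling described above can be sketched as follows. This covers only the space to U+2581 mapping and its inverse, not vocabulary lookup or the rest of the tokenizer; the function names are hypothetical and do not come from tokenizer.cpp.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// UTF-8 bytes of U+2581, the SentencePiece whitespace marker.
static const std::string kSpmSpace = "\xE2\x96\x81";

// Encode direction: replace every ASCII space with the U+2581 marker.
std::string spm_escape_spaces(const std::string& text) {
    std::string out;
    out.reserve(text.size());
    for (char c : text) {
        if (c == ' ') {
            out += kSpmSpace;
        } else {
            out += c;
        }
    }
    return out;
}

// Decode direction: replace every U+2581 marker with an ASCII space.
std::string spm_unescape_spaces(const std::string& piece) {
    std::string out;
    std::size_t i = 0;
    while (i < piece.size()) {
        if (piece.compare(i, kSpmSpace.size(), kSpmSpace) == 0) {
            out += ' ';
            i += kSpmSpace.size();
        } else {
            out += piece[i++];
        }
    }
    return out;
}
```

The two functions are inverses of each other only for text that does not already contain U+2581.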

- Fix KV cache layout: [D, max_ctx, Hk] matching Qwen35 convention,
  with per-head strided snapshot save/restore.

- Add Gemma4 chat template: <bos><|turn>user\n...<turn|>\n<|turn>model\n

- Map Gemma4 thinking channel (<|channel>...<channel|>) to existing
  <think>...</think> reasoning system for proper content separation.

- Add eos_chat_id detection for <turn|> token (id 106).

- Fix special token filtering in both streaming and non-streaming paths.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
park() now frees snapshots, KV cache, and model weights (releasing GPU
memory). unpark() reloads weights from disk and recreates the KV cache.

Also adds parked guards to generate(), restore_and_generate(), and
snapshot_save() to prevent use while model is parked.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
G5: SWA layers now allocate min(sliding_window, max_ctx) KV cache
instead of full max_ctx. Ring-buffer write (kv_start % swa_size) and
ring-aware attention mask enable bounded memory for sliding-window
layers. Prefill chunks are capped to avoid ring wrap.

G6: Added fa_window config for sparse decode. Full-attention layers
limit their FA read to the last fa_window positions during decode,
reducing compute at long contexts.
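
The G5 and G6 sizing rules above can be sketched as plain index arithmetic. All three function names are hypothetical. The fallback to max_ctx for a non-positive sliding_window follows the review comment earlier in this thread; the chunk cap is one plausible reading of "capped to avoid ring wrap" (a chunk never crosses the end of the ring), and treating fa_window <= 0 as "no limit" is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// G5: rows allocated for an SWA layer's KV cache.
// Assumption: a non-positive sliding_window falls back to max_ctx.
int64_t swa_cache_rows(int64_t sliding_window, int64_t max_ctx) {
    if (sliding_window <= 0) {
        return max_ctx;
    }
    return std::min(sliding_window, max_ctx);
}

// G5: largest prefill chunk starting at kv_start that stays inside the ring
// without wrapping (assumed interpretation of the chunk cap).
int64_t max_chunk_without_wrap(int64_t kv_start, int64_t swa_size, int64_t want) {
    if (swa_size <= 0) {
        return 0;  // unusable ring: write nothing rather than divide by zero
    }
    const int64_t slot = kv_start % swa_size;
    return std::min(want, swa_size - slot);
}

// G6: first KV position a full-attention layer reads during decode.
// Assumption: fa_window <= 0 disables the limit.
int64_t fa_read_start(int64_t kv_len, int64_t fa_window) {
    if (fa_window <= 0) {
        return 0;
    }
    return std::max<int64_t>(0, kv_len - fa_window);
}
```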

G3: Ported PFlash compress pipeline from Qwen35. Parks target,
lazy-loads Qwen3-0.6B drafter, runs score_and_compress, emits
surviving tokens, unparks. Drafter stays resident (~1.4 GB).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add target_feat ring buffer to Gemma4Cache for feature capture
- Add feature capture nodes to build_gemma4_layer() (both step and verify)
- Add draft model loading with metadata override (GGUF has wrong dimensions)
- Infer n_capture_layers from fc weight shape (6 for Gemma4, not 5 from metadata)
- Port do_spec_decode() loop from qwen35 backend
- Wire spec-decode into generate() and restore_and_generate() (temp==0 only)
- Sync captured features to DraftFeatureMirror after each prefill chunk
- Store last_tok during prefill for spec-decode entry
- Pass draft_path/draft_gpu/draft_ctx_max through BackendArgs to Gemma4BackendConfig
- Clean up draft resources in shutdown()

Tested: AR decode produces correct output, spec-decode pipeline runs
end-to-end with 9.1 tok/s throughput.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Gemma4 uses <|turn> / <turn|> as single-token turn delimiters.
Previously it incorrectly fell through to the Laguna family check
because <system>/<user>/etc. would encode to non-empty sequences.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Three root-cause fixes identified from the HuggingFace model card
(z-lab/gemma-4-31B-it-DFlash config.json):

1. mask_token_id: use 4 instead of 0 — the draft model was trained
   with token 4 as the mask/padding token.

2. capture_layer_ids: replace integer-truncation formula with
   floating-point linspace + rounding. For 60 layers / 6 captures:
   old: {1,12,23,34,45,56}, correct: {1,12,23,35,46,57}.

3. embed_tokens: remove sqrt(n_embd) scaling — the draft model
   expects raw unscaled embeddings (same as qwen35 convention).
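
The difference in item 2 can be reproduced with a short sketch. The endpoints (first id 1, last id n_layers - 3) are an assumption inferred from the two id lists quoted above, and the function names are hypothetical; the PR's actual formula may be written differently.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Old behaviour (assumed form): the step is computed with integer division,
// so the fractional part is dropped and the later ids drift low.
// Requires n_capture >= 2.
std::vector<int> capture_ids_truncating(int n_layers, int n_capture) {
    const int first = 1;
    const int last = n_layers - 3;  // assumed endpoint, 57 for 60 layers
    const int step = (last - first) / (n_capture - 1);
    std::vector<int> ids;
    for (int i = 0; i < n_capture; ++i) {
        ids.push_back(first + i * step);
    }
    return ids;
}

// Fixed behaviour: floating-point linspace, rounded to the nearest layer id.
// Requires n_capture >= 2.
std::vector<int> capture_ids_linspace(int n_layers, int n_capture) {
    const double first = 1.0;
    const double last = static_cast<double>(n_layers - 3);
    const double step = (last - first) / (n_capture - 1);
    std::vector<int> ids;
    for (int i = 0; i < n_capture; ++i) {
        ids.push_back(static_cast<int>(std::lround(first + i * step)));
    }
    return ids;
}
```

For 60 layers and 6 captures the step is 56 / 5 = 11.2; truncating it to 11 gives the old list, while rounding each point gives the corrected one. With 5 captures the step is exactly 14, so both versions agree, which is consistent with the earlier 5-layer list (1, 15, 29, 43, 57).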

Also removes debug fprintf statements added during investigation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add causal attention mask for SWA layers in the draft model (layers 0-3
  are sliding-window with causal masking, layer 4 is full non-causal).
  The draft was trained this way; running all-non-causal let future MASK
  embeddings leak into earlier positions, hurting acceptance rate.
- Read rope_theta from draft GGUF metadata instead of hardcoded 10M constant
  (Gemma4 draft uses 1M, not 10M like Qwen3.5).
- Remove double-normalization: gemma4_project_hidden now skips out_norm since
  the draft already applies its own final norm layer.
- Scale embed_tokens by sqrt(n_embd) in DFlashTarget to match Gemma4 convention.
- Set swa_window=2048 and mark layers[0..3].is_swa after draft GGUF loading.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Match the pattern from attn_masks.h — create the causal mask tensor
as GGML_TYPE_F16 directly and fill with uint16_t values (0x0000 for
attend, 0xFC00 for -inf). This eliminates the intermediate ggml_cast
op in the draft graph and reduces memory usage.
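
As an illustration of filling an F16 mask with raw bit patterns, here is a minimal host-side sketch. The row-major [n_q][n_kv] layout, the kv_start parameter, and the function name are assumptions for the example; the draft graph's real mask shape and padding are not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr uint16_t kF16Zero   = 0x0000;  // +0.0 in IEEE 754 half precision: attend
constexpr uint16_t kF16NegInf = 0xFC00;  // -inf in IEEE 754 half precision: masked

// Builds a causal mask as raw F16 bits. Query row q sits at absolute position
// kv_start + q and may attend to key columns 0..kv_start + q inclusive.
// Assumed layout: row-major, one row of n_kv entries per query.
std::vector<uint16_t> build_causal_mask_f16(int n_q, int n_kv, int kv_start) {
    std::vector<uint16_t> mask(static_cast<std::size_t>(n_q) * n_kv, kF16NegInf);
    for (int q = 0; q < n_q; ++q) {
        const int q_pos = kv_start + q;
        for (int k = 0; k < n_kv && k <= q_pos; ++k) {
            mask[static_cast<std::size_t>(q) * n_kv + k] = kF16Zero;
        }
    }
    return mask;
}
```

Writing the bit patterns directly means the buffer can be uploaded as an F16 tensor with no F32 staging copy and no cast node.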

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The implementation file now matches its header (draft_graph.h), eliminating
confusion with the similarly-named common/dflash_draft_graph.cpp orchestrator.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The 10M default was Qwen3.5-specific and silently wrong for other models
(e.g. Gemma4 uses 1M). Now rope_theta must come from the draft GGUF
metadata; a warning is printed if the key is missing.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Gemma4 is a pure transformer — after verify, KV entries at accepted
positions are already correct (causal masking guarantees independence
from rejected tokens). Replace the expensive snapshot → verify →
restore → replay pattern with:

  verify(16 tokens) → truncate KV → bonus(1 token)

This eliminates:
- 2x full KV cache copies (60 layers × K + V each direction)
- The replay forward pass (~9 tokens through 60 layers)

Measured ~2.2x speedup on RTX 2080 Ti (9.5 → 21 tok/s).
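
The accept-then-truncate bookkeeping can be sketched without any model code. The verify-batch layout is an assumption: verify_in[0] is the last committed token, verify_in[1..] are the draft guesses, and target_argmax[i] is the target's prediction for the token that follows verify_in[i]. The struct and function names are hypothetical and are not taken from do_spec_decode().

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct SpecStepResult {
    int accepted;    // number of draft tokens that matched the target
    int new_kv_len;  // KV length to keep; later entries are discarded
    int bonus_tok;   // target's own token at the first mismatch position
};

// Requires verify_in.size() == target_argmax.size() and both non-empty.
SpecStepResult accept_and_truncate(const std::vector<int>& verify_in,
                                   const std::vector<int>& target_argmax,
                                   int kv_len_before) {
    int accepted = 0;
    // Accept draft guesses while each one equals the target's prediction.
    while (static_cast<std::size_t>(accepted) + 1 < verify_in.size() &&
           verify_in[accepted + 1] == target_argmax[accepted]) {
        ++accepted;
    }
    // KV rows for verify_in[0..accepted] depend only on accepted tokens
    // (causal masking), so they are kept and everything after is truncated.
    return {accepted, kv_len_before + 1 + accepted, target_argmax[accepted]};
}
```

The bonus token is then fed through a single-token forward pass, which is the bonus(1 token) step in the pipeline above.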

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…d dispatch

- Implement gemma4_prefill_bsa() for per-layer BSA prefill using
  flash_prefill_forward for SWA layers (head_dim=128) with dense FA
  fallback for full-attention layers (head_dim=256).
- Write KV cache during Graph A (ring-buffer aware for SWA layers).
- Add GGML_ASSERT guard for swa_size > 0 before modulo operation.
- Add flash_prefill_forward() unified dispatch to flashprefill.h that
  selects bf16/f16/q8 kernel based on compile flags + buffer type.
- Simplify Qwen3 attention dispatch to use the unified function.
- Remove duplicated ifdef boilerplate from both model implementations.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
After restoring KV from a snapshot, do_prefill only syncs the feature mirror
for the delta tokens [snap_pos..committed). The positions [0..snap_pos) in
the mirror retain stale data from the previous request's decode phase (which
may have diverged from the current prompt context after the ring buffer
wraps).

Fix: call draft_feature_mirror_sync_tail after restore to resync the entire
[0..committed) feature range from cache_.target_feat to the mirror. This
ensures the draft model sees consistent features and maintains high
acceptance rate (AL) during speculative decoding.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

gemma4: save/restore target_feat in prefix cache snapshot

Matching Qwen35's approach: save target_feat (BF16 feature ring buffer) and
last_tok as part of the KV snapshot. On restore, target_feat is copied back
to GPU before the delta prefill + feature mirror resync.

Previously, only K/V tensors were snapshotted. After restore, the feature
mirror contained stale data from the previous request's decode phase, causing
the draft model to make poor predictions and halving speculative decode
acceptance rate (52% → 24%).

With this fix, the full feature state is correctly restored, and the
subsequent draft_feature_mirror_sync_tail ensures the mirror matches.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Davide Cifarelli and others added 2 commits May 22, 2026 11:06
Three fixes for gemma-4-26B-A4B-it (unsloth UD-Q4_K_M).

1. gemma4_graph.cpp:116 — GGML_ASSERT(ggml_is_contiguous(src0))
   crash in ggml_cuda_op_gelu. gate_e and up_e are strided
   ggml_view_3d halves of fused gate_up_e; CUDA gelu requires
   contiguous src. Insert ggml_cont before ggml_gelu.

2. gemma4_loader.cpp tensor name mismatches with actual GGUF
   metadata (silently loaded null → MoE produced gibberish):
     ffn_gate_inp_shexp.weight → ffn_gate_inp.scale
     ffn_down_exps_s.weight    → ffn_down_exps.scale
     ffn_pre_norm_2.weight     → pre_ffw_norm_2.weight
     ffn_post_norm_1.weight    → post_ffw_norm_1.weight
     ffn_post_norm_2.weight    → post_ffw_norm_2.weight

3. leading_dense_block_count default 1 → 0. Gemma-4-26B-A4B GGUF
   does not store this key; old default skipped MoE on layer 0,
   running shared-expert only and corrupting downstream.

Verified: 'What is 2+2?' returns '2 + 2 = 4' on lucebox2 RTX 3090.

Co-Authored-By: WOZCODE <contact@withwoz.com>
# Conflicts:
#	dflash/src/gemma4/gemma4_backend.h
@davide221 davide221 merged commit 6467da5 into Luce-Org:main May 22, 2026
2 of 3 checks passed
pull Bot pushed a commit to HSKIMRobert/lucebox-hub that referenced this pull request May 22, 2026
PR Luce-Org#236 (placement refactor) replaced BackendArgs::draft_gpu with
BackendArgs::draft_device but only updated the qwen35 caller. PR Luce-Org#232
(gemma4 DFlash spec decode) merged after Luce-Org#236 and re-introduced
args.draft_gpu in the gemma4 branch, breaking compilation of
dflash_common on main.

Caught by PR Luce-Org#252 CI build.
3 participants